Deep Dive into PyTorch Internals: Autograd and torch.fx

PyTorch
Deep Learning
Author

Ismail TG

Published

April 5, 2025


Introduction

As I continue to deepen my understanding of machine learning systems, I’ve realized that knowing how models run is just as important as knowing how to build them. This post kicks off a series where I explore PyTorch internals, starting with two powerful components: Autograd and torch.fx.

PyTorch Autograd – The Engine Behind Gradient Descent

What Is Autograd?

PyTorch’s autograd is a dynamic automatic differentiation engine. It records operations on tensors to build a computation graph during the forward pass, and then traverses that graph in reverse to compute gradients during the backward pass.
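To make "dynamic" concrete, here is a small sketch. The function f and its inputs are made up for illustration. Because the graph is recorded while the code runs, ordinary Python control flow decides which operations end up in it, and the gradients follow whichever branch was actually taken.

```python
import torch

def f(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical example: the branch taken depends on the data,
    # so the graph autograd records can differ from call to call.
    if x.sum() > 0:
        return x ** 2
    return -3 * x

a = torch.tensor([1.0, 2.0], requires_grad=True)
f(a).sum().backward()
print(a.grad)  # gradient of sum(x**2) is 2x -> tensor([2., 4.])

b = torch.tensor([-1.0, -2.0], requires_grad=True)
f(b).sum().backward()
print(b.grad)  # gradient of sum(-3x) is -3 -> tensor([-3., -3.])
```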

How It Works Internally

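When you operate on torch.Tensor objects that have requires_grad=True, PyTorch:

  1. Builds the computation graph on the fly, as each operation executes.
  2. Records a backward node for each differentiable operation and attaches it to the result as grad_fn (e.g., AddBackward0, MulBackward0).
  3. When .backward() is called, traverses the graph from the output back to the leaf tensors, applying the chain rule at each node (reverse-mode automatic differentiation) and accumulating the results into each leaf's .grad.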
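The graph recorded in steps 1 and 2 can be inspected directly. The sketch below walks it from y.grad_fn through each node's next_functions attribute. The helper backward_nodes is my own hypothetical name, not a PyTorch API. Leaf tensors show up as AccumulateGrad nodes, which is where step 3 deposits gradients.

```python
import torch

def backward_nodes(fn, depth=0, out=None):
    """Hypothetical helper: depth-first walk over grad_fn.next_functions,
    returning (depth, node class name) pairs."""
    if out is None:
        out = []
    if fn is None:  # inputs that need no gradient (e.g. Python scalars) appear as None
        return out
    out.append((depth, type(fn).__name__))
    for next_fn, _input_nr in fn.next_functions:
        backward_nodes(next_fn, depth + 1, out)
    return out

x = torch.tensor(2.0, requires_grad=True)
y = x * x + 3 * x

# Print the recorded backward graph as an indented tree, output node first
for depth, name in backward_nodes(y.grad_fn):
    print("  " * depth + name)
```

A node can be reached more than once (x feeds both inputs of x * x), so the walk prints AccumulateGrad several times; a real tool would deduplicate visited nodes.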

Example: Simple Chain Rule

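import torch
x = torch.tensor(2.0, requires_grad=True)  # leaf tensor tracked by autograd
y = x * x + 3 * x                          # records a multiply node per product and an add node
z = y.mean()                               # mean of a single value is the value itself; stands in for a loss
z.backward()                               # reverse-mode pass from z back to x
print(x.grad)  # d(x^2 + 3x)/dx = 2x + 3 = 7 at x = 2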
tensor(7.)
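Because x is a 0-d tensor above, y.mean() returns y unchanged, so the reduction has no visible effect. A vector input shows what the mean contributes to the chain rule. This sketch uses four arbitrary input values and compares autograd's result against the hand-derived gradient (2x_i + 3)/n.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
y = x * x + 3 * x   # elementwise: y_i = x_i^2 + 3 x_i
z = y.mean()        # z = (1/n) * sum_i y_i, with n = 4
z.backward()

# Chain rule by hand: dz/dx_i = (dz/dy_i) * (dy_i/dx_i) = (1/n) * (2 x_i + 3)
expected = (2 * x.detach() + 3) / x.numel()
print(x.grad)                            # tensor([1.2500, 1.7500, 2.2500, 2.7500])
print(torch.allclose(x.grad, expected))  # True
```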

Key Internals

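  • Tensor.grad_fn: the backward node of the operation that produced the tensor; None for leaf tensors created directly by the user.
  • Tensor.grad: the gradient that backward() accumulates into a leaf tensor; non-leaf tensors only populate it if you call retain_grad() on them.
  • torch.autograd.Function: base class for custom differentiable operations, where you write forward and backward yourself.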
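The three pieces above fit together in a short sketch. Square is a hypothetical custom operation written with torch.autograd.Function. The example shows grad_fn being None on a leaf, the backward node attached to the output, and .grad accumulating across two backward calls.

```python
import torch

class Square(torch.autograd.Function):
    """Hypothetical custom op: computes x**2 with a hand-written backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # stash the input that backward will need
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x  # d(x^2)/dx = 2x, chained with the upstream gradient

x = torch.tensor(3.0, requires_grad=True)  # leaf tensor
y = Square.apply(x)                         # custom Functions are invoked via .apply

print(x.grad_fn)                 # None: a leaf was not produced by an operation
print(type(y.grad_fn).__name__)  # SquareBackward
y.backward()
print(x.grad)                    # tensor(6.)

# .grad accumulates across backward calls until you reset it (e.g. x.grad = None)
Square.apply(x).backward()
print(x.grad)                    # tensor(12.)
```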

torch.fx – PyTorch’s Intermediate Representation

Why Use torch.fx?

torch.fx lets you capture PyTorch programs as Python-level graphs (usually via symbolic tracing) and transform them. This is useful for:

  • Programmatic model transformations
  • Debugging and visualization
  • Building custom compiler backends
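As one example of the debugging use case, torch.fx.Interpreter re-executes a traced graph node by node, and subclassing it lets you observe every intermediate value. The sketch below records tensor shapes. TinyNet and ShapeRecorder are illustrative names; PyTorch ships a more complete version of the same idea as torch.fx.passes.shape_prop.ShapeProp.

```python
import torch
import torch.fx as fx

class ShapeRecorder(fx.Interpreter):
    """Illustrative subclass: runs a GraphModule node by node and records
    the shape of every tensor result, keyed by node name."""

    def __init__(self, gm: fx.GraphModule):
        super().__init__(gm)
        self.shapes = {}

    def run_node(self, n: fx.Node):
        result = super().run_node(n)  # execute the node normally
        if isinstance(result, torch.Tensor):
            self.shapes[n.name] = tuple(result.shape)
        return result

class TinyNet(torch.nn.Module):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.linear(x))

gm = fx.symbolic_trace(TinyNet())
recorder = ShapeRecorder(gm)
recorder.run(torch.randn(8, 4))
for name, shape in recorder.shapes.items():
    print(name, shape)
```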

Core Components

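  • Tracer: symbolically executes a module's forward() with placeholder Proxy values and records every operation it encounters into a Graph.
  • Graph: an ordered list of Node objects, each representing an input (placeholder), an operation (call_function, call_method, call_module, get_attr), or the output.
  • GraphModule: an nn.Module that holds a Graph and generates its forward() code from it; after editing the graph, call recompile() to regenerate that code.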
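fx.symbolic_trace bundles these components into a single call. Spelling the steps out, as in the sketch below, shows how they relate: Tracer.trace produces a Graph, and GraphModule combines that Graph with the original module into runnable code. The model is the same toy MyModel used in the next example.

```python
import torch
import torch.fx as fx

class MyModel(torch.nn.Module):
    def forward(self, x):
        return x * 2 + 3

model = MyModel()

# 1. Tracer: runs forward() with Proxy objects and records each operation
graph = fx.Tracer().trace(model)

# 2. Graph: an ordered collection of Nodes
for node in graph.nodes:
    print(node.op, node.target, node.args)

# 3. GraphModule: wraps the Graph in an nn.Module and generates forward()
gm = fx.GraphModule(model, graph)
print(gm.code)
print(gm(torch.tensor([1.0, 2.0])))  # tensor([5., 7.])
```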

import torch
import torch.nn as nn
import torch.fx as fx

class MyModel(nn.Module):
    def forward(self, x):
        return x * 2 + 3

model = MyModel()
traced = fx.symbolic_trace(model)  # trace forward() and wrap the result in a GraphModule
print(traced.graph)                # print the captured graph IR
graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%mul, 3), kwargs = {})
    return add
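To illustrate programmatic model transformations, here is a toy pass over the graph printed above. It relies only on public torch.fx APIs: it iterates graph.nodes, retargets every operator.add node to operator.sub, and calls recompile() so the GraphModule regenerates forward() from the edited graph. The pass name replace_add_with_sub is hypothetical.

```python
import operator
import torch
import torch.fx as fx

class MyModel(torch.nn.Module):
    def forward(self, x):
        return x * 2 + 3

def replace_add_with_sub(gm: fx.GraphModule) -> fx.GraphModule:
    """Hypothetical toy pass: rewrite every `a + b` node into `a - b` (mutates gm in place)."""
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is operator.add:
            node.target = operator.sub
    gm.graph.lint()  # sanity-check the edited graph
    gm.recompile()   # regenerate forward() from the edited graph
    return gm

traced = fx.symbolic_trace(MyModel())
x = torch.tensor([1.0, 2.0])
print(traced(x))                # x * 2 + 3 -> tensor([5., 7.])
transformed = replace_add_with_sub(traced)
print(transformed(x))           # now computes x * 2 - 3, i.e. values -1 and 1
```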

Real-World Use Cases

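  • TorchDynamo, the graph-capture frontend of torch.compile, turns Python bytecode into FX graphs and hands them to a backend such as TorchInductor, which compiles them into optimized kernels.
  • FX enables quantization and pruning workflows by letting tools insert or rewrite operations in the graph; PyTorch's FX graph mode quantization (torch.ao.quantization.quantize_fx), for example, inserts observers into the traced graph.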
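Insertion works in a similar way. The sketch below imitates what a quantization observer does by recording the value range of every call_function result. observe, RANGES, and insert_observers are hypothetical helpers for illustration, not the API used by torch.ao.quantization. Each observer node is placed with graph.inserting_after, and the original node's consumers are rewired to read from it.

```python
import torch
import torch.fx as fx

RANGES = {}  # hypothetical stats store: node name -> (min, max) seen so far

def observe(t: torch.Tensor, name: str) -> torch.Tensor:
    """Hypothetical observer: record t's value range and return t unchanged."""
    lo, hi = t.min().item(), t.max().item()
    if name in RANGES:
        lo, hi = min(lo, RANGES[name][0]), max(hi, RANGES[name][1])
    RANGES[name] = (lo, hi)
    return t

def insert_observers(gm: fx.GraphModule) -> fx.GraphModule:
    for node in list(gm.graph.nodes):  # snapshot, since we mutate the graph while looping
        if node.op != "call_function":
            continue
        users = list(node.users)       # consumers captured before the observer exists
        with gm.graph.inserting_after(node):
            obs = gm.graph.call_function(observe, args=(node, node.name))
        for user in users:
            user.replace_input_with(node, obs)  # route consumers through the observer
    gm.graph.lint()
    gm.recompile()
    return gm

class MyModel(torch.nn.Module):
    def forward(self, x):
        return x * 2 + 3

gm = insert_observers(fx.symbolic_trace(MyModel()))
out = gm(torch.tensor([1.0, 2.0]))
print(out)     # unchanged result: tensor([5., 7.])
print(RANGES)  # {'mul': (2.0, 4.0), 'add': (5.0, 7.0)}
```

Capturing node.users before inserting matters: the new observer itself becomes a user of node, and rewiring it too would make it consume its own output.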

Conclusion

Both Autograd and torch.fx are essential for understanding what happens under the hood in PyTorch. Whether you’re debugging models, optimizing inference, or building custom backends, mastering these tools opens the door to deeper systems-level work in AI.

In future posts, I plan to explore:

  • Implementing custom autograd functions
  • Writing your own FX passes for transformations
  • Diving into TorchDynamo and TorchInductor